Tutorial ML 2
A common task in BCI research is to test a machine learning model (MLM) on a large amount of real data. This tutorial uses the FII BCI Corpus to carry out such a task.
If you have not downloaded the corpus yet, do so with the downloadDB function before running this tutorial.
The tutorial shows how to:
1. select databases and sessions from the FII BCI Corpus according to
   - the BCI paradigm (Motor Imagery in this example),
   - the availability of specific classes,
   - a minimum number of trials per class;
2. run a cross-validation on every selected session of every selected database and store the resulting balanced accuracies;
3. compute the average balanced accuracy within each database using an appropriate weighting function.
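Why the weights are later normalized to unit mean can be seen on a toy case. The accuracies and weights below are made-up numbers, not output of weightsDB or results from the corpus: once the weights have unit mean, `mean(w .* acc)` is exactly the classical weighted average `sum(w .* acc) / sum(w)`.

```julia
using Statistics  # provides mean

# hypothetical per-session balanced accuracies and weights for ONE database
acc = [0.70, 0.85, 0.90]
w   = [2.0, 1.0, 1.0]

# normalize the weights to unit mean
wn = w ./ mean(w)

# the mean of the weighted accuracies...
avg1 = mean(wn .* acc)
# ...equals the classical weighted average
avg2 = sum(w .* acc) / sum(w)

println(round(avg1; digits = 4), " ", round(avg2; digits = 4))
```

With unit-mean weights, sessions with a weight above 1 pull the database average toward their accuracy, and equal weights reduce it to the ordinary mean.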
As the MLM, the minimum distance to mean (MDM) Riemannian classifier employing the affine-invariant (Fisher-Rao) metric is used (Barachant et al., 2012). As the covariance matrix estimator, the linear shrinkage estimator of Ledoit and Wolf (2004) is used. These state-of-the-art settings are the defaults in Eegle.
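To make these two ingredients concrete, here is a minimal sketch using only Julia's standard libraries. It is not Eegle's implementation: the names spdfun, ledoitWolf, distAI, meanAI and predictMDM are hypothetical, the Fisher-Rao mean is obtained with a plain fixed-point iteration (one of several possible algorithms), and the data are synthetic.

```julia
using LinearAlgebra, Random, Statistics

# apply a scalar function f to the eigenvalues of a symmetric matrix A
function spdfun(f, A)
    F = eigen(Symmetric(Matrix(A)))
    return Symmetric(F.vectors * Diagonal(f.(F.values)) * F.vectors')
end

# Ledoit-Wolf (2004) shrinkage toward a scaled identity; X is samples × channels
function ledoitWolf(X::AbstractMatrix{<:Real})
    n, p = size(X)
    Xc = X .- mean(X; dims = 1)                  # center each channel
    S = (Xc' * Xc) ./ n                          # sample covariance (1/n)
    mu = tr(S) / p                               # scale of the identity target
    d2 = sum(abs2, S - mu * I) / p               # dispersion of S around the target
    b2 = sum(sum(abs2, Xc[k, :] * Xc[k, :]' - S) for k in 1:n) / (n^2 * p)
    b2 = min(b2, d2)                             # keeps the intensity in [0, 1]
    shrink = d2 > 0 ? b2 / d2 : 1.0
    return Symmetric(shrink * mu * I + (1 - shrink) * S)
end

# affine-invariant (Fisher-Rao) distance ‖log(A^{-1/2} B A^{-1/2})‖_F
function distAI(A, B)
    W = spdfun(x -> 1 / sqrt(x), A)
    return sqrt(sum(abs2, log.(eigvals(Symmetric(W * B * W)))))
end

# Fisher-Rao mean by fixed-point iteration in the tangent space at M
function meanAI(Cs; tol = 1e-10, maxiter = 100)
    M = Symmetric(mean(Matrix.(Cs)))             # start from the arithmetic mean
    for _ in 1:maxiter
        Mh, Mih = spdfun(sqrt, M), spdfun(x -> 1 / sqrt(x), M)
        T = mean(Matrix(spdfun(log, Mih * C * Mih)) for C in Cs)
        M = Symmetric(Mh * spdfun(exp, T) * Mh)
        norm(T) < tol && break                   # tangent mean ≈ 0: converged
    end
    return M
end

# MDM: assign a trial to the class whose mean is nearest
predictMDM(means, C) = argmin([distAI(M, C) for M in means])

# synthetic 4-channel trials; class 2 has a larger variance on channel 1
rng = MersenneTwister(1)
makeTrial(s) = randn(rng, 256, 4) .* [s 1 1 1]
train1 = [ledoitWolf(makeTrial(1.0)) for _ in 1:20]
train2 = [ledoitWolf(makeTrial(3.0)) for _ in 1:20]
means = [meanAI(train1), meanAI(train2)]
pred = predictMDM(means, ledoitWolf(makeTrial(3.0)))  # index of the nearest class mean
```

Training an MDM classifier amounts to computing one mean per class; prediction is a single distance per class, which is why it is fast enough to cross-validate many sessions.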
For each session, an 8-fold stratified cross-validation is run. During the computations, a summary of the results is printed for each session, including the mean and standard deviation of the balanced accuracy across the folds and the p-value of the cross-validation test statistic.
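Balanced accuracy is the mean of the per-class recalls, so it is not inflated when one class has more trials than the other, and a stratified split keeps the class proportions of the session in every fold. The sketch below illustrates both ideas; balancedAccuracy and stratifiedFolds are hypothetical helpers, not Eegle functions, and the fold assignment used internally by crval may differ.

```julia
using Random, Statistics

# balanced accuracy: mean over classes of the fraction of correctly predicted trials
function balancedAccuracy(ytrue::AbstractVector, ypred::AbstractVector)
    return mean(mean(ypred[ytrue .== c] .== c) for c in unique(ytrue))
end

# stratified k-fold assignment: shuffle the trials of each class separately
# and deal them round-robin to the k folds
function stratifiedFolds(y::AbstractVector, k::Int; rng = Random.default_rng())
    fold = zeros(Int, length(y))
    for c in unique(y)
        idx = shuffle(rng, findall(==(c), y))
        for (j, i) in enumerate(idx)
            fold[i] = mod1(j, k)
        end
    end
    return fold
end

# 40 trials per class split into 8 folds: 5 trials of each class per fold
y = vcat(fill("feet", 40), fill("right_hand", 40))
folds = stratifiedFolds(y, 8; rng = MersenneTwister(0))
counts = [count(i -> folds[i] == f && y[i] == "feet", eachindex(y)) for f in 1:8]

# unbalanced toy predictions: all "feet" trials correct, half of "right_hand" correct
ytrue = vcat(fill("feet", 10), fill("right_hand", 4))
ypred = vcat(fill("feet", 10), fill("right_hand", 2), fill("feet", 2))
bacc = balancedAccuracy(ytrue, ypred)   # (1.0 + 0.5) / 2
```

Here the plain accuracy would be 12/14 ≈ 0.857, while the balanced accuracy (0.75) reveals that the minority class is recognized only at chance level.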
Tell Julia to use Eegle:
using Eegle
using Statistics # provides mean, used below
Select all motor imagery databases in the FII BCI Corpus featuring the "feet" and "right_hand" classes. Within these databases, select the sessions featuring at least 30 trials for each of these classes; see selectDB.
classes = ["feet", "right_hand"];
DBs = selectDB(:MI; classes, minTrials = 30);
Allocate memory to store all accuracies (one vector per database, one entry per session):
MIacc = [zeros(length(DB.files)) for DB ∈ DBs];
Perform the cross-validation on all selected sessions of all selected databases:
for (db, DB) ∈ enumerate(DBs), (f, file) ∈ enumerate(DB.files)
    # perform cross-validation (using Eegle)
    cv = crval(file; upperLimit = 1.2, bandPass = (8, 32), classes)
    # store accuracy
    MIacc[db][f] = cv.avgAcc
    # print a summary of the cv results
    println("\nDatabase ", DB.dbName, "-", DB.condition, ", File ", f,
            ": mean(sd) balanced accuracy ", round(cv.avgAcc * 100, digits = 2),
            "% (± ", round(cv.stdAcc * 100, digits = 2), "%); ",
            "p-value ", round(cv.p; digits = 4))
end
Show all MI accuracies (rounded to three decimals):
allMIacc = [round.(db; digits = 3) for db ∈ MIacc]
Create appropriate weights for averaging the balanced accuracies within each database using the weightsDB function, then compute the weighted average balanced accuracy of each database:
MIw = [weightsDB(db.files)[1] for db ∈ DBs]; # get weights
MIw = [v ./ mean(v) for v ∈ MIw]; # normalize to unit mean
MIdbAcc = [mean(w .* acc) for (w, acc) ∈ zip(MIw, allMIacc)]
The output you will see:
7-element Vector{Float64}:
0.8011111111111111
0.8283333333333331
0.762857142857143
0.830615454071332
0.8523076923076922
0.7055555555555555
0.8743571862516337
For all options available when running cross-validations, see Eegle.BCI.crval.
Do not use Julia's Threads.@threads macro in the loops above: by default, crval is already multi-threaded across folds.
Code for Tutorial ML 2
using Eegle
using Statistics # provides mean
classes = ["feet", "right_hand"];
DBs = selectDB(:MI; classes, minTrials = 30);
MIacc = [zeros(length(DB.files)) for DB ∈ DBs];
for (db, DB) ∈ enumerate(DBs), (f, file) ∈ enumerate(DB.files)
    # perform cross-validation (using Eegle)
    cv = crval(file; upperLimit = 1.2, bandPass = (8, 32), classes)
    # store accuracy
    MIacc[db][f] = cv.avgAcc
    # print a summary of the cv results
    println("\nDatabase ", DB.dbName, "-", DB.condition, ", File ", f,
            ": mean(sd) balanced accuracy ", round(cv.avgAcc * 100, digits = 2),
            "% (± ", round(cv.stdAcc * 100, digits = 2), "%); ",
            "p-value ", round(cv.p; digits = 4))
end
allMIacc = [round.(db; digits = 3) for db ∈ MIacc]
MIw = [weightsDB(db.files)[1] for db ∈ DBs]; # get weights
MIw = [v ./ mean(v) for v ∈ MIw]; # normalize to unit mean
MIdbAcc = [mean(w .* acc) for (w, acc) ∈ zip(MIw, allMIacc)]